import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from bball_data.utils import unnormalize, plot_sequence

def num_trainable_params(model):
    """Return the number of trainable scalar parameters in `model`.

    Only parameters with ``requires_grad=True`` are counted, so frozen
    layers do not inflate the total.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (e.g. an ``nn.Module``).

    Returns:
        int: Sum of ``numel()`` over all parameters that require gradients.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def run_epoch(train, model, exp_data, clip, optimizer=None, batch_size=64):
    """Run one pass over `exp_data` in shuffled mini-batches and return the mean loss.

    Only full batches are processed: a trailing remainder of fewer than
    `batch_size` samples is skipped. If `exp_data` holds fewer than
    `batch_size` samples no batch is run and the result is NaN (mean of an
    empty list).

    Args:
        train: If True, back-propagate each batch loss, clip gradients and
            step `optimizer`. If False, only evaluate (with autograd disabled).
        model: Callable mapping a (time, batch, x) tensor to a scalar loss
            tensor. Must expose ``parameters()`` when `train` is True.
        exp_data: Tensor indexed as (batch, time, x).
        clip: Maximum gradient norm passed to ``clip_grad_norm_``.
        optimizer: Optimizer to step; required when `train` is True.
        batch_size: Number of sequences per mini-batch.

    Returns:
        NumPy scalar: mean of the per-batch losses.

    Raises:
        ValueError: If `train` is True but no `optimizer` was given.
    """
    if train and optimizer is None:
        raise ValueError("optimizer must be provided when train is True")

    n_samples = exp_data.shape[0]
    order = np.random.permutation(n_samples)
    losses = []

    # Every start index here satisfies start + batch_size <= n_samples.
    for start in range(0, n_samples - batch_size + 1, batch_size):
        ind = torch.from_numpy(order[start:start + batch_size]).long()

        # change (batch, time, x) to (time, batch, x)
        data = exp_data[ind].transpose(0, 1)

        # Skip building the autograd graph when we are only evaluating.
        with torch.set_grad_enabled(train):
            batch_loss = model(data)

        if train:
            optimizer.zero_grad()
            batch_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()

        losses.append(batch_loss.detach().cpu().numpy())

    return np.mean(losses)

def pretrain_dis_iter(G, D, D_opt, exp_data, clip, batch_size, criterion=nn.BCELoss()):
    """Run a single discriminator-only update step.

    A batch of real sequences is drawn from `exp_data` (uniformly, with
    replacement) and an equally sized batch is generated by `G`. `D` is
    updated to score the real batch as 1 and the generated batch as 0.
    `G` is not trained: generation runs with autograd disabled, so no
    gradients are accumulated on its parameters.

    Args:
        G: Generator exposing ``generate(batch_size, seq_len, device=...)``.
        D: Discriminator; callable on a (time, batch, x) batch, returning
            values accepted by `criterion` (probabilities for the default
            ``BCELoss``).
        D_opt: Optimizer over `D`'s parameters.
        exp_data: Real data tensor indexed as (batch, time, x).
        clip: Maximum gradient norm for `D`.
        batch_size: Number of real and of generated sequences to use.
        criterion: Loss taking ``(predictions, targets)``.

    Returns:
        tuple: ``(real_probs_val, fake_probs_val)`` - the mean discriminator
        output on the real and generated batch, as 0-dim CPU tensors,
        measured before the update.
    """
    device = exp_data.device

    exp_ind = torch.from_numpy(np.random.choice(exp_data.shape[0], batch_size)).long()
    data = exp_data[exp_ind].clone().transpose(0, 1)  # change (batch, time, x) to (time, batch, x)

    # Only D is optimised here, so do not record a graph through G.
    with torch.no_grad():
        fake_data = G.generate(batch_size, data.size(0), device=device)

    # Discriminate real vs fake
    real_probs = D(data)
    fake_probs = D(fake_data)

    real_probs_val = real_probs.detach().cpu().mean()
    fake_probs_val = fake_probs.detach().cpu().mean()
    print('real:', real_probs_val, 'fake:', fake_probs_val)

    # Targets follow whatever shape D emits.
    real_labels = torch.ones_like(real_probs)
    fake_labels = torch.zeros_like(fake_probs)

    # Update D
    D_opt.zero_grad()
    loss = criterion(real_probs, real_labels) + criterion(fake_probs, fake_labels)
    loss.backward()
    nn.utils.clip_grad_norm_(D.parameters(), clip)
    D_opt.step()

    return real_probs_val, fake_probs_val

def train_gan_iter(G, D, G_opt, D_opt, exp_data, clip, batch_size, criterion=nn.BCELoss()):
    """Run one adversarial step: update `D`, then update `G`.

    The discriminator is trained on a real batch (target 1) and on a
    detached generated batch (target 0). The generator is then trained so
    that the *updated* discriminator scores the same generated batch as 1.
    The generated batch is re-scored after the discriminator step because
    back-propagating through a graph whose weights were changed in place by
    ``D_opt.step()`` is rejected by autograd.

    Args:
        G: Generator exposing ``generate(batch_size, seq_len, device=...)``;
            assumed to return a tensor that is differentiable w.r.t. `G`'s
            parameters.
        D: Discriminator; callable on a (time, batch, x) batch, returning
            values accepted by `criterion` (probabilities for the default
            ``BCELoss``).
        G_opt: Optimizer over `G`'s parameters.
        D_opt: Optimizer over `D`'s parameters.
        exp_data: Real data tensor indexed as (batch, time, x).
        clip: Maximum gradient norm, applied to `D` and `G` separately.
        batch_size: Number of real and of generated sequences to use.
        criterion: Loss taking ``(predictions, targets)``.

    Returns:
        tuple: ``(real_probs_val, fake_probs_val)`` - the mean discriminator
        output on the real and generated batch, as 0-dim CPU tensors,
        measured before either update.
    """
    device = exp_data.device

    exp_ind = torch.from_numpy(np.random.choice(exp_data.shape[0], batch_size)).long()
    data = exp_data[exp_ind].clone().transpose(0, 1)  # change (batch, time, x) to (time, batch, x)
    fake_data = G.generate(batch_size, data.size(0), device=device)

    # Discriminate real vs fake. The fake batch is detached so the
    # discriminator loss does not back-propagate into G.
    real_probs = D(data)
    fake_probs = D(fake_data.detach())

    real_probs_val = real_probs.detach().cpu().mean()
    fake_probs_val = fake_probs.detach().cpu().mean()
    print('real:', real_probs_val, 'fake:', fake_probs_val)

    # Update D
    D_opt.zero_grad()
    d_loss = (criterion(real_probs, torch.ones_like(real_probs))
              + criterion(fake_probs, torch.zeros_like(fake_probs)))
    d_loss.backward()
    nn.utils.clip_grad_norm_(D.parameters(), clip)
    D_opt.step()

    # Update G: fresh forward pass through the updated D so the graph is
    # consistent with D's current weights.
    G_opt.zero_grad()
    gen_probs = D(fake_data)
    g_loss = criterion(gen_probs, torch.ones_like(gen_probs))
    g_loss.backward()
    nn.utils.clip_grad_norm_(G.parameters(), clip)
    G_opt.step()

    return real_probs_val, fake_probs_val


def impute(G, D, data, mask, n_iters=1, lambd=0.0, lr=1):
    """Fill the unobserved entries of `data` by optimising a latent code of `G`.

    A latent ``z`` is sampled from a standard normal and refined by plain
    gradient descent so that ``G.decode(z, T)`` matches `data` on the
    observed entries, optionally also maximising the discriminator score.
    Gradients are taken with respect to ``z`` only, so the ``.grad`` fields
    of `G`'s and `D`'s parameters are left untouched.

    Args:
        G: Generator exposing ``z_dim`` and ``decode(z, T)``.
        D: Discriminator, callable on decoded samples. Only evaluated when
            `lambd` is non-zero.
        data: 3-D tensor whose first two dimensions are (time, batch).
        mask: Tensor multiplied elementwise with `data`; entries equal to 1
            are kept from `data`, entries equal to 0 are filled from the
            generator. Its first two dimensions must match `data`.
        n_iters: Number of gradient steps on ``z``.
        lambd: Weight of the discriminator term in the loss.
        lr: Step size; divided by 10 once iteration ``n_iters * 4 // 5`` is
            reached.

    Returns:
        tuple: ``(imputed, mask)``, both detached from the autograd graph.
    """
    T, batch_size, _ = data.size()
    assert mask.size(0) == T
    assert mask.size(1) == batch_size

    # Create the latent on the same device as the data it is compared with.
    z = torch.randn(batch_size, G.z_dim, device=data.device, requires_grad=True)
    use_discrim = lambd != 0

    for i in range(n_iters):
        samples = G.decode(z, T)

        reconst_loss = F.mse_loss(data*mask, samples*mask)
        if use_discrim:
            discrim_loss = -torch.mean(D(samples))
            loss = reconst_loss + lambd*discrim_loss
        else:
            # A zero-weighted discriminator term contributes nothing; skip it.
            loss = reconst_loss

        if (i+1) % 10 == 0:
            if lambd > 0:
                print("LOSS [{}/{}] | reconst {:4f} | discrim {:4f} | total {:4f}".format(
                    i+1, n_iters, reconst_loss.item(), discrim_loss.item(), loss.item()))
            else:
                print("LOSS [{}/{}] | reconst {:4f}".format(i+1, n_iters, reconst_loss.item()))

        if i == n_iters * 4 // 5:
            lr = lr / 10

        # Differentiate w.r.t. z only; nothing is accumulated on G or D.
        (z_grad,) = torch.autograd.grad(loss, z)

        with torch.no_grad():
            z -= lr * z_grad

    with torch.no_grad():
        samples = G.decode(z, T)
        imputed = data*mask + (1-mask)*samples

    return imputed.detach(), mask.detach()


# draw and compute statistics
def draw_and_stats(model_states, name, i_iter, compute_stats=True, draw=True, missing_list=None):
    """Compute summary statistics of generated trajectories and optionally plot one.

    Args:
        model_states: 3-D tensor indexed as (time, batch, features). The last
            axis is read as interleaved coordinate pairs: even indices are
            combined with the following odd indices.
        name: Label used in the saved image file name.
        i_iter: Iteration index used in the saved image file name.
        compute_stats: If True, fill the returned dict with statistics.
        draw: If True, plot the first sequence of the batch via
            ``plot_sequence`` after ``unnormalize``.
        missing_list: Forwarded unchanged to ``plot_sequence``.

    Returns:
        dict: Empty when `compute_stats` is False, otherwise holds
        ``'ave_change_step_size'``, ``'ave_length'`` and
        ``'ave_out_of_bound'``.
    """
    stats = {}

    if compute_stats:
        # Per-step displacement between consecutive time steps.
        deltas = model_states[1:, :, :] - model_states[:-1, :, :]

        states_np = model_states.cpu().numpy()
        deltas_np = deltas.cpu().numpy()

        # Euclidean length of every (even, odd) displacement pair.
        first_coord = deltas_np[:, :, ::2]
        second_coord = deltas_np[:, :, 1::2]
        step_size = np.sqrt(np.square(first_coord) + np.square(second_coord))

        stats['ave_change_step_size'] = np.mean(np.abs(np.diff(step_size, axis=0)))

        # Summing over time leaves one path length per (batch, pair).
        stats['ave_length'] = np.mean(np.sum(step_size, axis=0))

        outside = (states_np < -0.51) | (states_np > 0.51)
        stats['ave_out_of_bound'] = np.mean(outside)

    if draw:
        print("Drawing")
        first_sequence = unnormalize(model_states.cpu().numpy()[:, 0, :])
        palette = ['b', 'r', 'g', 'm', 'y']
        plot_sequence(
            first_sequence,
            macro_goals=None,
            colormap=palette[:5],
            save_name="imgs/{}_{}_offense".format(name, i_iter),
            missing_list=missing_list,
        )

    return stats